import argparse
import inspect
import json
import os

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

from _models.model import get_embedding_func_batched
from _datasets.data import DatasetConfig
from utils.transform_utils import *
from utils.string_utils import *
from utils.metrics import *


class RetrievalExperimentConfig:
    """Measure how text augmentation affects retrieval on an MTEB-style dataset.

    Each corpus document is followed, in an augmented copy of the corpus, by
    the outputs of the augmentation helpers (shuffle_text, shuffle_words,
    prune_text, capitalize_random, attack_text, numerize_text, negate_text).
    Retrieval quality is computed for the original and the augmented corpus,
    per similarity metric and for a random-forest ensemble of all metrics.
    The post/pre ratios are written to
    ``data/<model_name>/<dataset_name>_retrieval.json`` ('/' replaced by '_'
    in both names).
    """

    # Similarity metrics computed for every (query, document) pair, in the
    # order they are evaluated and stacked as ensemble features.
    _METRICS = ("cosine", "levenshtein", "rouge", "bm25", "jaccard")

    def __init__(
        self,
        dataset_name: str,
        num_examples: int,
        model_name: str = "BAAI/bge-small-en-v1.5",
        max_length: int = 8192,
    ):
        """Load the dataset, resolve the embedding function, create the output dir.

        Args:
            dataset_name: MTEB dataset name (e.g. ``"trec-covid"``) or a full
                hub id containing '/' (e.g. ``"mteb/trec-covid"``).
            num_examples: Stored as ``self.num_examples``; not read by any
                method of this class.
            model_name: Passed to ``get_embedding_func_batched`` and used to
                name the output directory.
            max_length: Passed to ``truncate`` for every corpus document.
        """
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.num_examples = num_examples
        self.max_length = max_length

        self.load_data(dataset_name)

        print(f"Dataset {dataset_name} loaded.")

        self.embedding_func = get_embedding_func_batched(model_name)
        self.results = {}

        # One directory per model; '/' is replaced so hub ids such as
        # "BAAI/bge-small-en-v1.5" do not create nested subdirectories.
        self.model_data_path = os.path.join(
            "data", self.model_name.replace("/", "_")
        )
        os.makedirs(self.model_data_path, exist_ok=True)

    def load_data(self, dataset_name=None) -> None:
        """Load corpus, queries and test-split qrels and join them together.

        Sets ``self.corpus`` (list of truncated "title. text" strings),
        ``self.corpus_df``, ``self.default_df`` (test qrels joined with query
        and document text) and ``self.queries``.

        Args:
            dataset_name: Dataset to load; defaults to ``self.dataset_name``.
                A name without '/' is looked up under the ``mteb/`` namespace;
                a name containing '/' is used as the full dataset id.
        """
        if dataset_name is None:
            dataset_name = self.dataset_name
        # Accept both "trec-covid" and "mteb/trec-covid" without producing
        # "mteb/mteb/...".
        repo_id = dataset_name if "/" in dataset_name else f"mteb/{dataset_name}"

        corpus = load_dataset(repo_id, "corpus")
        queries = load_dataset(repo_id, "queries")
        default = load_dataset(repo_id, "default")

        default_df = default["test"].to_pandas()
        # Sets give O(1) membership tests inside the per-row filters below.
        test_corpus_ids = set(default_df["corpus-id"])
        test_query_ids = set(default_df["query-id"])

        # Queries without test qrel rows have nothing relevant to retrieve and
        # would only contribute zero scores / all-negative labels.
        queries_df = queries["queries"].to_pandas()
        queries_df = queries_df[queries_df["_id"].isin(test_query_ids)].reset_index(
            drop=True
        )

        corpus_df = (
            corpus["corpus"]
            .filter(lambda x: x["_id"] in test_corpus_ids)
            .to_pandas()
        )

        # Truncate once and store the truncated text everywhere so that
        # self.corpus, corpus_df["text"] and default_df["corpus-text"] hold
        # identical strings; the text-based lookups in score_sims and
        # get_ensembled_scores depend on that.
        corpus_df["text"] = (corpus_df["title"] + ". " + corpus_df["text"]).apply(
            lambda x: truncate(x, self.max_length)
        )

        default_df = default_df.join(
            queries_df.set_index("_id"), on="query-id"
        ).rename(columns={"text": "query-text"})
        default_df = default_df.join(
            corpus_df.set_index("_id"), on="corpus-id"
        ).rename(columns={"text": "corpus-text", "title": "corpus-title"})

        self.corpus = corpus_df["text"].to_list()
        self.corpus_df = corpus_df
        self.default_df = default_df
        self.queries = queries_df

    def run(self) -> None:
        """Run augmentation, embedding, scoring and ensembling; save the ratios."""
        self.augment_data()
        print("Augmented data.")

        self.generate_embeddings(
            embedding_func=self.embedding_func,
            model_name=self.model_name,
            use_gpu=True,
        )
        print("Generated embeddings.")

        self.calculate_similarities()
        print("Calculated similarities.")

        self.fit_ensembling()
        print("Fitted ensembling.")

        # self.results maps metric name -> scalar ratio; cast to plain floats
        # so json can serialize NumPy scalars.
        self.results = {k: float(v) for k, v in self.results.items()}
        dataset_tag = self.dataset_name.replace("/", "_")
        results_file_path = os.path.join(
            self.model_data_path, f"{dataset_tag}_retrieval.json"
        )
        with open(results_file_path, "w", encoding="utf-8") as f:
            json.dump(self.results, f)
        print(f"Saved results to {results_file_path}.")

    def augment_data(self) -> None:
        """Build ``self.augmented_corpus``: each document followed by its variants.

        For every document the original text comes first, then the output of
        each augmentation helper in a fixed order. Each helper is assumed to
        return a single string — TODO confirm against utils.
        """
        augmenters = (
            shuffle_text,
            shuffle_words,
            prune_text,
            capitalize_random,
            attack_text,
            numerize_text,
            negate_text,
        )
        self.augmented_corpus = []
        for text in self.corpus:
            self.augmented_corpus.append(text)
            self.augmented_corpus.extend(augment(text) for augment in augmenters)

    @staticmethod
    def _to_list(embeds):
        """Return ``embeds`` as a Python list (calls ``.tolist()`` if needed)."""
        return embeds if isinstance(embeds, list) else embeds.tolist()

    def generate_embeddings(self, embedding_func, **kwargs) -> None:
        """Embed the augmented corpus, the original corpus and the queries.

        If the source of ``embedding_func`` does not mention "huggingface",
        ``model_name`` is renamed to ``model`` and ``use_gpu`` is dropped
        before calling it.
        """
        source_code = inspect.getsource(embedding_func)
        if "huggingface" not in source_code:
            kwargs["model"] = kwargs.pop("model_name")
            kwargs.pop("use_gpu", None)

        self.embeddings_augmented_corpus = self._to_list(
            embedding_func(prompts=self.augmented_corpus, **kwargs)
        )
        self.embeddings_corpus = self._to_list(
            embedding_func(prompts=self.corpus, **kwargs)
        )
        # Pass a plain list of strings rather than an object ndarray.
        self.embeddings_queries = self._to_list(
            embedding_func(prompts=self.queries["text"].tolist(), **kwargs)
        )

    def _relevance_rows(self, i):
        """Return the qrel rows of ``self.default_df`` for the i-th query."""
        return self.default_df[self.default_df["query-id"] == self.queries["_id"][i]]

    def _query_similarities(self, i, corpus, corpus_embeddings):
        """Compute every metric in ``_METRICS`` between query ``i`` and ``corpus``.

        Each metric helper is assumed to return one score per corpus entry
        (1-D) — TODO confirm against utils.metrics.
        """
        query_repeated = [self.queries["text"][i]] * len(corpus)
        query_embedding = torch.tensor(self.embeddings_queries[i]).unsqueeze(0)
        return {
            "cosine": cosine_similarity(query_embedding, corpus_embeddings),
            "levenshtein": levenshtein_ratio(query_repeated, corpus),
            "rouge": rouge_score(query_repeated, corpus),
            "bm25": bm25_score(query_repeated, corpus),
            "jaccard": jaccard_similarity(query_repeated, corpus),
        }

    def _text_to_id(self):
        """Map each document text in ``self.corpus_df`` to its first ``_id``."""
        mapping = {}
        for text, doc_id in zip(self.corpus_df["text"], self.corpus_df["_id"]):
            mapping.setdefault(text, doc_id)
        return mapping

    def score_sims(self, sims, scores, corpus, text_to_id=None):
        """Sum the qrel scores of relevant documents ranked in the top-k.

        k is the number of qrel rows for the query (``len(scores)``). Each of
        the k highest entries of ``sims`` is mapped to its text in ``corpus``.
        Texts that are not documents of ``self.corpus_df`` (e.g. augmented
        variants) are skipped. The rest are mapped to their corpus id, and if
        that id appears in ``scores`` its qrel score is added.

        Args:
            sims: One similarity per entry of ``corpus``. Either an object with
                an ``argsort`` method (NumPy array / torch tensor) or a plain
                sequence.
            scores: Qrel rows for one query, with ``corpus-id`` and ``score``
                columns.
            corpus: Texts the similarities were computed against.
            text_to_id: Optional precomputed text -> corpus-id mapping. It is
                built from ``self.corpus_df`` when omitted.

        Returns:
            The summed qrel score; 0 when the query has no qrel rows.
        """
        k = len(scores)
        if k == 0:
            # A `[-0:]` slice would select every index; nothing can score anyway.
            return 0
        if text_to_id is None:
            text_to_id = self._text_to_id()

        order = sims.argsort() if hasattr(sims, "argsort") else np.argsort(sims)

        # The first qrel row wins for duplicate corpus ids.
        id_to_score = {}
        for corpus_id, rel in zip(scores["corpus-id"], scores["score"]):
            id_to_score.setdefault(corpus_id, rel)

        total = 0
        for idx in order[-k:]:
            text = corpus[int(idx)]
            if text not in text_to_id:
                continue
            text_id = text_to_id[text]
            if text_id in id_to_score:
                total += id_to_score[text_id]
        return total

    def calculate_score(self, corpus, corpus_embeddings):
        """Return the mean ``score_sims`` over all queries, per metric."""
        text_to_id = self._text_to_id()
        per_metric = {name: [] for name in self._METRICS}
        for i in range(len(self.queries)):
            relevant = self._relevance_rows(i)
            sims = self._query_similarities(i, corpus, corpus_embeddings)
            for name in self._METRICS:
                per_metric[name].append(
                    self.score_sims(sims[name], relevant, corpus, text_to_id)
                )
        return {name: np.mean(values) for name, values in per_metric.items()}

    def calculate_similarities(self) -> None:
        """Store post/pre-augmentation score ratios per metric in ``self.results``."""
        pre = self.calculate_score(self.corpus, self.embeddings_corpus)
        post = self.calculate_score(
            self.augmented_corpus, self.embeddings_augmented_corpus
        )
        for metric, post_value in post.items():
            self.results[metric] = post_value / pre[metric]

    def get_ensembled_scores(self, corpus, corpus_embeddings, n_splits: int = 1000):
        """Return mean held-out accuracy of a random forest over all metrics.

        Every (query, corpus entry) pair is one sample. Its features are the
        metrics in ``_METRICS``, and its label is 1 when the entry's text is
        the ``corpus-text`` of one of the query's qrel rows. NOTE(review):
        every qrel row counts as positive regardless of its ``score`` value;
        confirm that is intended.

        Args:
            corpus: Texts to score against each query.
            corpus_embeddings: Embeddings aligned with ``corpus``.
            n_splits: Number of random 80/20 splits; seed ``i`` is used for
                both the split and the forest of iteration ``i``.
        """
        features = {name: [] for name in self._METRICS}
        labels = []
        for i in range(len(self.queries)):
            relevant_texts = set(self._relevance_rows(i)["corpus-text"])
            sims = self._query_similarities(i, corpus, corpus_embeddings)
            for name in self._METRICS:
                features[name].extend(sims[name])
            labels.extend(1 if text in relevant_texts else 0 for text in corpus)

        X = np.array([features[name] for name in self._METRICS]).T
        y = np.array(labels)

        accuracies = []
        for seed in range(n_splits):
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=seed
            )
            ensemble = RandomForestClassifier(random_state=seed)
            ensemble.fit(X_train, y_train)
            accuracies.append(ensemble.score(X_test, y_test))
        return np.mean(accuracies)

    def fit_ensembling(self) -> None:
        """Store the post/pre-augmentation ensemble accuracy ratio in ``self.results``."""
        pre_augmentation_score = self.get_ensembled_scores(
            self.corpus, self.embeddings_corpus
        )
        post_augmentation_score = self.get_ensembled_scores(
            self.augmented_corpus, self.embeddings_augmented_corpus
        )

        self.results["ensembled"] = post_augmentation_score / pre_augmentation_score


def main(
    dataset_name="mteb/trec-covid",
    num_examples=5,
    model_name="embed-english-v3.0",
    max_length=8192,
):
    """Build a RetrievalExperimentConfig from the given settings and run it.

    Args:
        dataset_name: Dataset passed to RetrievalExperimentConfig.
        num_examples: Passed through to RetrievalExperimentConfig.
        model_name: Embedding model name passed to RetrievalExperimentConfig.
        max_length: Truncation length passed to RetrievalExperimentConfig.
    """
    experiment = RetrievalExperimentConfig(
        dataset_name=dataset_name,
        num_examples=num_examples,
        model_name=model_name,
        max_length=max_length,
    )
    experiment.run()


if __name__ == "__main__":
    # Command-line entry point. The defaults mirror main()'s keyword defaults,
    # so running the script without flags is equivalent to calling main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="mteb/trec-covid")
    parser.add_argument("--num_examples", type=int, default=5)
    parser.add_argument("--model_name", type=str, default="embed-english-v3.0")
    parser.add_argument("--max_length", type=int, default=8192)
    args = parser.parse_args()

    main(
        dataset_name=args.dataset_name,
        num_examples=args.num_examples,
        model_name=args.model_name,
        max_length=args.max_length,
    )
